// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/transform/num_workgroups_from_uniform.h"

#include <memory>
#include <string>
#include <unordered_set>
#include <utility>

#include "src/program_builder.h"
#include "src/sem/function.h"
#include "src/transform/canonicalize_entry_point_io.h"
#include "src/utils/hash.h"

// Instantiate the type information for the transform and for its Config data.
// NOTE(review): TINT_INSTANTIATE_TYPEINFO is defined outside this file;
// presumably it emits the definitions behind `TypeInfo()` and the typed
// `inputs.Get<Config>()` lookup used in Run() below - confirm.
TINT_INSTANTIATE_TYPEINFO(tint::transform::NumWorkgroupsFromUniform);
TINT_INSTANTIATE_TYPEINFO(tint::transform::NumWorkgroupsFromUniform::Config);

namespace tint {
namespace transform {
namespace {
/// Accessor is the (parameter, member) symbol pair that identifies a member
/// accessor expression which reads the num_workgroups builtin from an entry
/// point parameter.
struct Accessor {
  /// Symbol of the parameter that is being accessed
  Symbol param;
  /// Symbol of the member that is read from the parameter
  Symbol member;

  /// Equality operator
  /// @param other the accessor to compare against
  /// @returns true if both the parameter and the member symbols are equal
  bool operator==(const Accessor& other) const {
    if (!(param == other.param)) {
      return false;
    }
    return member == other.member;
  }

  /// Hash function, so that Accessor can be used as an unordered container key
  struct Hasher {
    /// @param accessor the accessor to hash
    /// @returns a hash of the parameter and member symbols
    size_t operator()(const Accessor& accessor) const {
      return utils::Hash(accessor.param, accessor.member);
    }
  };
};
}  // namespace

// The constructor and destructor need no custom logic; both are defaulted
// here, out of line, in this translation unit.
NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default;
NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default;

// Replaces every read of the num_workgroups builtin in compute entry points
// with a load from a uniform buffer whose binding point is taken from the
// Config found in `inputs`. The builtin struct members (and, when they become
// empty, the structs and parameters that carried them) are removed from the
// cloned program. An error diagnostic is raised if no Config was provided.
void NumWorkgroupsFromUniform::Run(CloneContext& ctx,
                                   const DataMap& inputs,
                                   DataMap&) {
  // Nothing is done unless the CanonicalizeEntryPointIO requirement holds.
  if (!Requires<CanonicalizeEntryPointIO>(ctx)) {
    return;
  }

  // The configuration supplies the binding point of the uniform buffer.
  auto* config = inputs.Get<Config>();
  if (!config) {
    ctx.dst->Diagnostics().add_error(
        diag::System::Transform,
        "missing transform data for " + std::string(TypeInfo().name));
    return;
  }

  // Name of the single member of the generated uniform buffer structure.
  const char* kUboMemberName = "num_workgroups";

  // Pass 1: collect the (parameter, member) symbol pairs through which the
  // num_workgroups builtin is read, and strip the builtin from the entry
  // point interface.
  std::unordered_set<Accessor, Accessor::Hasher> builtin_accessors;
  for (auto* entry_point : ctx.src->AST().Functions()) {
    // Only compute stages may use num_workgroups.
    const bool is_compute =
        entry_point->PipelineStage() == ast::PipelineStage::kCompute;
    if (!is_compute) {
      continue;
    }

    auto* sem_func = ctx.src->Sem().Get(entry_point);
    for (auto* param : sem_func->Parameters()) {
      // CanonicalizeEntryPointIO has already run, so a builtin can only show
      // up as a member of a structure-typed parameter.
      auto* struct_ty = param->Type()->As<sem::Struct>();
      if (struct_ty == nullptr) {
        continue;
      }

      for (auto* member : struct_ty->Members()) {
        auto* member_decl = member->Declaration();
        auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
            member_decl->decorations);
        const bool is_num_workgroups =
            builtin && builtin->builtin == ast::Builtin::kNumWorkgroups;
        if (!is_num_workgroups) {
          continue;
        }

        auto* param_decl = param->Declaration();
        auto* struct_decl = struct_ty->Declaration();

        // Remember which symbols form an access to this member. There is
        // currently no direct route from the parameter to the member accessor
        // expressions that use it, so they are matched by symbol in pass 2.
        builtin_accessors.insert(
            Accessor{param_decl->symbol, member_decl->symbol});

        // Drop the member from the structure. CanonicalizeEntryPointIO
        // generated this structure solely for this entry point, so nothing
        // else in the module uses it and it is safe to modify.
        ctx.Remove(struct_decl->members, member_decl);

        // A structure that held nothing but the builtin is now empty: drop
        // the parameter and the structure declaration as well.
        if (struct_ty->Members().size() == 1) {
          ctx.Remove(entry_point->params, param_decl);
          ctx.Remove(ctx.src->AST().GlobalDeclarations(), struct_decl);
        }
      }
    }
  }

  // Lazily builds the uniform buffer that will hold the number of workgroups.
  // The buffer is only created the first time it is requested, so modules
  // that never read the builtin do not gain a new global.
  const ast::Variable* ubo_var = nullptr;
  auto ubo = [&]() {
    if (ubo_var) {
      return ubo_var;
    }
    auto* ubo_struct = ctx.dst->Structure(
        ctx.dst->Sym(),
        {ctx.dst->Member(kUboMemberName,
                         ctx.dst->ty.vec3(ctx.dst->ty.u32()))});
    ubo_var = ctx.dst->Global(
        ctx.dst->Sym(), ctx.dst->ty.Of(ubo_struct),
        ast::StorageClass::kUniform,
        ast::DecorationList{ctx.dst->GroupAndBinding(
            config->ubo_binding.group, config->ubo_binding.binding)});
    return ubo_var;
  };

  // Pass 2: redirect each `param.member` expression recorded in pass 1 to the
  // member of the uniform buffer.
  for (auto* node : ctx.src->ASTNodes().Objects()) {
    if (auto* access = node->As<ast::MemberAccessorExpression>()) {
      if (auto* object = access->structure->As<ast::IdentifierExpression>()) {
        const Accessor key{object->symbol, access->member->symbol};
        if (builtin_accessors.find(key) != builtin_accessors.end()) {
          ctx.Replace(access, ctx.dst->MemberAccessor(ubo()->symbol,
                                                      kUboMemberName));
        }
      }
    }
  }

  ctx.Clone();
}

// Constructs the transform configuration.
// `ubo_bp` is stored in `ubo_binding`; Run() uses its group and binding
// numbers for the uniform buffer that receives the number of workgroups.
NumWorkgroupsFromUniform::Config::Config(sem::BindingPoint ubo_bp)
    : ubo_binding(ubo_bp) {}
// Copy constructor and destructor are defaulted.
NumWorkgroupsFromUniform::Config::Config(const Config&) = default;
NumWorkgroupsFromUniform::Config::~Config() = default;

}  // namespace transform
}  // namespace tint
